from typing import Dict, List

import pymysql
import copy
from pymysql.cursors import Cursor
from pymysql.connections import Connection
from dbutils.pooled_db import PooledDB
from urllib.parse import urlparse

from common.logger import logger


# @singleton  # intentionally disabled: this pool must NOT be a singleton
class MysqlPool(object):
    """MySQL connection-pool helper (deliberately NOT a singleton).

    Wraps a DBUtils ``PooledDB`` around pymysql.  On construction it also
    opens one long-lived connection/cursor pair (``self.conn`` / ``self.cursor``)
    shared by the non-auto-commit methods; methods called with
    ``auto_commit=True`` instead borrow a fresh connection from the pool,
    commit, and return it.
    """

    def __init__(self, url, **kwargs):
        """Build the pool from a ``mysql://user:pass@host:port/db`` URL.

        An empty/falsy *url* yields an unconnected instance (no ``POOL``,
        ``conn`` or ``cursor`` attributes are created in that case).
        Extra *kwargs* are forwarded to ``PooledDB``.
        """
        self.url = url
        if not url:
            return
        parsed = urlparse(url)
        self.POOL: PooledDB = PooledDB(
            creator=pymysql,
            maxconnections=10,  # hard cap of simultaneous connections
            maxcached=10,
            maxshared=10,
            blocking=True,  # wait for a free connection instead of raising
            setsession=[],
            host=parsed.hostname,
            port=parsed.port or 3306,  # default MySQL port when URL omits it
            user=parsed.username,
            password=parsed.password,
            database=parsed.path.strip().strip('/'),
            charset='utf8',
            **kwargs
        )
        # Shared long-lived connection used by the non-auto-commit helpers.
        self.conn: Connection = self.POOL.connection()
        self.cursor: Cursor = self.conn.cursor()

    @staticmethod
    def _insert_sql(table: str, columns: list) -> str:
        """Build ``insert into <table>(c1,c2) values (%s,%s)`` for *columns*."""
        placeholders = ",".join(["%s"] * len(columns))
        return f'insert into {table}({",".join(columns)}) values ({placeholders})'

    def __del__(self):
        # Guard every attribute: __init__ may have returned early (empty url),
        # and a destructor must never raise during interpreter shutdown.
        for name in ("cursor", "conn", "POOL"):
            obj = getattr(self, name, None)
            if obj is not None:
                try:
                    obj.close()
                except Exception:
                    pass  # best-effort cleanup only

    def __enter__(self):
        # Required counterpart of __exit__; without it ``with MysqlPool(...)``
        # raised AttributeError.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.POOL.close()

    def connect(self):
        """Borrow a fresh ``(connection, cursor)`` pair from the pool."""
        conn: Connection = self.POOL.connection()
        cursor: Cursor = conn.cursor()
        return conn, cursor

    def close(self):
        """Close the shared cursor/connection and shut the whole pool down."""
        self.cursor.close()
        self.conn.close()
        self.POOL.close()

    def connect_close(self, conn, cursor):
        """Return a borrowed connection to the pool (cursor closed first)."""
        cursor.close()
        conn.close()  # PooledDB connections go back to the pool on close()

    def fetch_one(self, sql, args, columns) -> dict:
        """Run *sql* and map the first row onto *columns*.

        Returns ``None`` when the query yields no row.
        """
        logger.debug(sql)
        conn, cursor = self.connect()
        try:
            cursor.execute(sql, args)
            result = cursor.fetchone()
        finally:
            # Bug fix: arguments were previously passed swapped (cursor, conn);
            # the finally also guarantees the connection is returned on error.
            self.connect_close(conn, cursor)
        return dict(zip(columns, result)) if result is not None else None

    def fetch_all(self, sql, args=None, columns: list = None) -> List[Dict]:
        """Run *sql* and return every row as a dict keyed by *columns*.

        When *columns* is omitted the names are taken from the cursor
        description, so ``select *`` queries work without repeating the schema.
        """
        logger.debug(sql)
        conn, cursor = self.connect()
        try:
            cursor.execute(sql, args)  # pymysql treats args=None as "no args"
            record_list = cursor.fetchall()
            if columns is None and cursor.description:
                columns = [d[0] for d in cursor.description]
        finally:
            self.connect_close(conn, cursor)
        return [dict(zip(columns, row)) for row in record_list]

    def select_ones(self, sql, *params) -> set:
        """Return the distinct values of the first result column as a set."""
        logger.debug(sql)
        self.cursor.execute(sql, *params)
        return {row[0] for row in self.cursor.fetchall()}

    def insert(self, table, columns, data, auto_commit=False):
        """Insert one row; returns the affected-row count.

        With ``auto_commit`` a pooled connection is used and committed at
        once; otherwise the shared cursor is used and ``commit()`` must be
        called by the caller.
        """
        sql = self._insert_sql(table, columns)
        logger.debug(sql)
        if auto_commit:
            conn, cursor = self.connect()
            row = cursor.execute(sql, data)
            conn.commit()
            self.connect_close(conn, cursor)
        else:
            row = self.cursor.execute(sql, data)
        return row

    def insert_batch(self, table, columns, data: list, auto_commit=False):
        """Insert many rows with ``executemany``; returns affected-row count."""
        sql = self._insert_sql(table, columns)
        logger.debug(sql)
        if auto_commit:
            conn, cursor = self.connect()
            row = cursor.executemany(sql, data)
            conn.commit()
            self.connect_close(conn, cursor)
        else:
            row = self.cursor.executemany(sql, data)
        return row

    def insertById(self, table: str, columns: list, data: list[tuple], key: str):
        """Insert only the rows whose *key* (first tuple element) is absent.

        Not committed; the caller is responsible for ``commit()``.
        WARNING: the id list is interpolated into the SQL text, so the first
        tuple element must be trusted/numeric data.
        """
        if not data:
            return  # empty input would otherwise produce invalid "in ()" SQL
        ids = [str(i[0]) for i in data]
        select = f"select {key} from {table} where {key} in ({','.join(ids)})"
        existing = self.select_ones(select)
        sql = self._insert_sql(table, columns)
        logger.info(sql)
        self.cursor.fast_executemany = True  # pyodbc flag; no-op under pymysql
        new_rows = [i for i in data if i[0] not in existing]
        if len(new_rows) > 0:
            num = self.cursor.executemany(sql, new_rows)
            logger.info(f"row: {num}")

    def insertOrUpdateById(self, table: str, columns: list, data: list[tuple], key: str = None, auto_commit=False):
        """Upsert rows via ``INSERT ... ON DUPLICATE KEY UPDATE``.

        Returns the affected-row count, or ``None`` when *data* is empty.
        Raises whatever the driver raises, after logging the failing SQL.
        """
        if len(data) == 0:
            return
        # Bug fix: a space was missing between ')' and 'on duplicate'.
        sql = (self._insert_sql(table, columns)
               + " on duplicate key update "
               + ",".join(f"{c}=values({c})" for c in columns))
        logger.info(sql)
        self.cursor.fast_executemany = True  # pyodbc flag; no-op under pymysql
        try:
            if auto_commit:
                conn, cursor = self.connect()
                cursor.fast_executemany = True
                row = cursor.executemany(sql, data)
                conn.commit()
                self.connect_close(conn, cursor)
            else:
                row = self.cursor.executemany(sql, data)
            logger.info(f"row: {row}")
            return row
        except Exception:
            logger.error(sql)
            raise

    def insertOrUpdateByKey(self, table: str, columns: list, data: list[dict], key: str, autocommit=False):
        """Insert dict rows whose *key* is new; batch-update the existing ones.

        Returns the insert affected-row count (``None`` if nothing inserted).
        """
        if not data:
            return None  # empty input would otherwise produce invalid "in ()" SQL
        ids = [str(i.get(key)) for i in data]
        placeholders = ("%s," * len(ids)).rstrip(",")
        select = f"select {key} from {table} where {key} in ({placeholders}) group by {key}"
        existing = {r.get(key) for r in self.fetch_all(select, ids, columns=[key])}
        sql = self._insert_sql(table, columns)
        logger.info(sql)
        self.cursor.fast_executemany = True  # pyodbc flag; no-op under pymysql
        save_data = [i for i in data if i.get(key) not in existing]
        update_data = [i for i in data if i.get(key) in existing]
        row = None
        if len(save_data) > 0:
            values = [[i.get(c) for c in columns] for i in save_data]
            if autocommit:
                conn, cursor = self.connect()
                cursor.fast_executemany = True
                row = cursor.executemany(sql, values)
                conn.commit()
                self.connect_close(conn, cursor)
            else:
                row = self.cursor.executemany(sql, values)
        logger.info(f"row: {row}")
        if len(update_data) > 0:
            update_columns = columns
            if key in update_columns:
                # Never rewrite the key column itself in the UPDATE clause.
                update_columns = [c for c in columns if c != key]
            # Bug fix: previously ALL rows (including the just-inserted ones)
            # were sent to updateBatchById; only pre-existing rows need updates.
            self.updateBatchById(table, update_columns, update_data, key, autocommit)
        return row

    def update(self, sql, arge=None):
        """Execute an update/delete on the shared cursor; returns row count.

        (Parameter name ``arge`` kept for backward compatibility.)
        """
        logger.debug(sql)
        return self.cursor.execute(sql, arge)

    def updateBatchById(self, table, columns, data: list[dict], key, autocommit=False):
        """``update <table> set c=%s,... where <key>=%s`` for every dict row."""
        sql = f"update {table} set {','.join([f' {i} = %s ' for i in columns])} where {key} = %s"
        logger.info(sql)
        params = [[i.get(c) for c in columns] + [i.get(key)] for i in data]
        if autocommit:
            conn, cursor = self.connect()
            cursor.fast_executemany = True
            row = cursor.executemany(sql, params)
            conn.commit()
            self.connect_close(conn, cursor)
        else:
            row = self.cursor.executemany(sql, params)
        logger.info(f"row: {row}")
        return row

    def execute(self, sql: str, arge=None):
        """Run arbitrary SQL on the shared cursor and fetch all rows."""
        logger.debug(sql)
        self.cursor.execute(sql, arge)
        return self.cursor.fetchall()

    def clear(self, table: str, auto_commit=False):
        """TRUNCATE *table* (irreversible); optionally commit immediately."""
        sql = f'truncate table {table}'
        logger.debug(sql)
        self.cursor.execute(sql)
        if auto_commit:
            self.conn.commit()

    def commit(self):
        """Commit pending work on the shared connection."""
        self.conn.commit()
        logger.debug(f"commit: {self.cursor}")

    def rollback(self):
        """Roll back pending work on the shared connection."""
        self.conn.rollback()
        logger.debug(f"rollback: {self.cursor}")


if __name__ == "__main__":
    # Smoke test against a live database.
    # SECURITY NOTE(review): the DSN below embeds plaintext credentials in
    # source control; move them to an environment variable or a config file.
    # A large commented-out block of real account e-mails, passwords and
    # phone numbers that used to sit here was removed for the same reason.
    pool = MysqlPool(url="mysql://root:wt123456@47.98.98.230:33006/lkm?charset=utf8")
    columns = ["id", "user_id", "account", "passwd", "phone"]
    print(pool.fetch_one("select * from tb_alliance_user where id = %s", 3373916720, columns))
